import torch
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver


# Plain-Python coefficient lists for a 14-stage explicit Runge-Kutta scheme.
# Further down they are converted to tensors and handed to `_ButcherTableau`
# (A -> alpha, B -> beta, C_sol -> c_sol, C_err -> c_error).
# NOTE(review): judging by the solver name these are presumably the
# Dormand-Prince 8(7) coefficients -- confirm against the published tableau.

# 13 node values, one per row of B.
A = [
    1/18,
    1/12,
    1/8,
    5/16,
    3/8,
    59/400,
    93/200,
    5490023248/9719169821,
    13/20,
    1201146811/1299019798,
    1,
    1,
    1,
]

# Lower-triangular stage coefficients: row i holds i + 1 entries.
B = [
    [1/18],
    [1/48, 1/16],
    [1/32, 0, 3/32],
    [5/16, 0, -75/64, 75/64],
    [3/80, 0, 0, 3/16, 3/20],
    [
        29443841/614563906, 0, 0,
        77736538/692538347, -28693883/1125000000, 23124283/1800000000,
    ],
    [
        16016141/946692911, 0, 0,
        61564180/158732637, 22789713/633445777, 545815736/2771057229,
        -180193667/1043307555,
    ],
    [
        39632708/573591083, 0, 0,
        -433636366/683701615, -421739975/2616292301, 100302831/723423059,
        790204164/839813087, 800635310/3783071287,
    ],
    [
        246121993/1340847787, 0, 0,
        -37695042795/15268766246, -309121744/1061227803, -12992083/490766935,
        6005943493/2108947869, 393006217/1396673457, 123872331/1001029789,
    ],
    [
        -1028468189/846180014, 0, 0,
        8478235783/508512852, 1311729495/1432422823, -10304129995/1701304382,
        -48777925059/3047939560, 15336726248/1032824649, -45442868181/3398467696,
        3065993473/597172653,
    ],
    [
        185892177/718116043, 0, 0,
        -3185094517/667107341, -477755414/1098053517, -703635378/230739211,
        5731566787/1027545527, 5232866602/850066563, -4093664535/808688257,
        3962137247/1805957418, 65686358/487910083,
    ],
    [
        403863854/491063109, 0, 0,
        -5068492393/434740067, -411421997/543043805, 652783627/914296604,
        11173962825/925320556, -13158990841/6184727034, 3936647629/1978049680,
        -160528059/685178525, 248638103/1413531060, 0,
    ],
    [
        14005451/335480064, 0, 0, 0, 0,
        -59238493/1068277825, 181606767/758867731, 561292985/797845732,
        -1041891430/1371343529, 760417239/1151165299, 118820643/751138087,
        -528747749/2220607170, 1/4,
    ],
]

# Solution weights: the final row of B followed by a zero for the 14th stage.
# A fresh list is built, so B's last row is not aliased.
C_sol = [*B[-1], 0]

# Error weights: each solution weight minus the matching entry of a second
# weight list (presumably the embedded lower-order formula -- confirm).
C_err = [
    weight - other
    for weight, other in zip(C_sol, [
        13451932/455176623, 0, 0, 0, 0,
        -808719846/976000145, 1757004468/5645159321, 656045339/265891186,
        -3867574721/1518517206, 465885868/322736535, 53011238/667516719,
        2/45, 0, 0,
    ])
]

# Point at which the weight polynomials below are evaluated.
h = 1/2


def _mid_weight(coeffs, theta):
    """Evaluate one weight polynomial at ``theta`` and scale it by ``theta``.

    ``coeffs`` lists the coefficients from the ``theta**5`` term down to the
    constant term.  Terms are accumulated strictly left to right (no
    built-in ``sum``) so the floating-point rounding is fixed by the
    coefficient order.
    """
    top = len(coeffs) - 1
    terms = [c * theta**(top - k) for k, c in enumerate(coeffs)]
    total = terms[0]
    for term in terms[1:]:
        total = total + term
    # Scaling by theta, expressed as a division by its reciprocal.
    return total / (1/theta)


# Stage index -> polynomial coefficients (theta**5 ... theta**0).
# Only stage 0 has a non-zero constant term; stages 1-4 have no polynomial.
_MID_POLYS = {
    0: (
        -6.3448349392860401388, 22.1396504998094068976,
        -30.0610568289666450593, 19.9990069333683970610,
        -6.6910181737837595697, 1.0,
    ),
    5: (
        -39.6107919852202505218, 116.4422149550342161651,
        -121.4999627731334642623, 52.2273532792945524050,
        -7.6142658045872677172, 0.0,
    ),
    6: (
        20.3761213808791436958, -67.1451318825957197185,
        83.1721004639847717481, -46.8919164181093621583,
        10.7281392630428866124, 0.0,
    ),
    7: (
        7.3347098826795362023, -16.5672243527496524646,
        9.5724507555993664382, -0.1890893225010595467,
        0.5526637063753648783, 0.0,
    ),
    8: (
        32.8801774352459155182, -89.9916014847245016028,
        87.8406057677205645007, -35.7075975946222072821,
        4.2186562625665153803, 0.0,
    ),
    9: (
        -10.1588990526426760954, 22.6237489648532849093,
        -17.4152107770762969005, 6.2736448083240352160,
        -0.6627209125361597559, 0.0,
    ),
    10: (
        -12.5401268098782561200, 32.2362340167355370113,
        -28.5903289514790976966, 10.3160881272450748458,
        -1.2636789001135462218, 0.0,
    ),
    11: (
        29.5553001484516038033, -82.1020315488359848644,
        81.6630950584341412934, -34.7650769866611817349,
        5.4106037898590422230, 0.0,
    ),
    12: (
        -41.7923486424390588923, 116.2662185791119533462,
        -114.9375291377009418170, 47.7457971078225540396,
        -7.0321379067945741781, 0.0,
    ),
    13: (
        20.3006925822100825485, -53.9020777466385396792,
        50.2558364226176017553, -19.0082099341608028453,
        2.3537586759714983486, 0.0,
    ),
}

# One weight per stage (14 in total); stages without a polynomial stay at 0.
# NOTE(review): presumably the dense-output weights at the step midpoint,
# given the name and h = 1/2 -- confirm against the solver base class.
C_mid = [
    _mid_weight(_MID_POLYS[stage], h) if stage in _MID_POLYS else 0.
    for stage in range(14)
]


# Rebind the coefficient lists as float64 tensors.  The dtype is given at
# construction time so the values are never routed through a narrower type.
A, C_sol, C_err, _C_mid = (
    torch.tensor(values, dtype=torch.float64)
    for values in (A, C_sol, C_err, C_mid)
)
# B stays a Python list, with one 1-D tensor per (ragged) row.
B = [torch.tensor(row, dtype=torch.float64) for row in B]

# Bundle the coefficients for the solver class defined below.
_DOPRI8_TABLEAU = _ButcherTableau(
    alpha=A,
    beta=B,
    c_sol=C_sol,
    c_error=C_err,
)


class Dopri8Solver(RKAdaptiveStepsizeODESolver):
    """Solver configured with the ``_DOPRI8_TABLEAU`` coefficients.

    The class defines no methods of its own; everything is inherited from
    ``RKAdaptiveStepsizeODESolver`` and only class-level configuration is
    set here.
    """

    # Coefficient bundle built at module level.
    tableau = _DOPRI8_TABLEAU
    # float64 tensor holding the 14 ``C_mid`` weights.
    mid = _C_mid
    # Declared order of the method.
    order = 8
